import glob
import json

import pandas as pd
import plotly.express as px
from joblib import Parallel, delayed
from tqdm import tqdm


def merge_stats(original, new):
    """Accumulate the (category, value) counts from `new` into `original` in place."""
    for stat_type, values in new.items():
        bucket = original.setdefault(stat_type, {})
        for key, count in values.items():
            bucket[key] = bucket.get(key, 0) + count

def glob_in_subdir(subdir_pattern):
    """Glob a single subdirectory pattern; kept at module level so joblib can pickle it."""
    return glob.glob(subdir_pattern, recursive=True)

def parallel_glob(base_pattern, n_jobs=-1):
    """Parallelize a recursive glob by searching one first-level subdirectory per worker."""
    # Split off the final path component, e.g. ".../dataset/**/*.json"
    # -> static_part = ".../dataset/**", wildcard = "*.json".
    static_part, _, wildcard = base_pattern.rpartition('/')
    # Without recursive=True, "**" behaves like "*", so this lists the
    # first-level entries under the dataset root.
    subdirs = glob.glob(static_part)

    # Keep the recursive "**" inside each subdirectory so the parallel search
    # matches the same files as a plain recursive glob of base_pattern
    # (JSON files sitting directly in the root are the one exception).
    subdir_patterns = [f"{subdir}/**/{wildcard}" for subdir in subdirs]

    # Search the subdirectories in parallel
    results = Parallel(n_jobs=n_jobs)(
        delayed(glob_in_subdir)(pattern) for pattern in tqdm(subdir_patterns, desc="Processing subdirectories"))

    # Flatten the list of lists
    all_files = [item for sublist in results for item in sublist]

    return all_files
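
# Hypothetical usage sketch: for a layout like root/<scene>/<frame>.json,
# parallel_glob("root/**/*.json") is expected to return the same paths as
# glob.glob("root/**/*.json", recursive=True), just searched one subtree per worker.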

def transform_to_dataframe_format(stats):
    """Flatten {file: {stat_type: {(category, value): count}}} into a long-form DataFrame."""
    rows = []
    for file_name, file_stats in stats.items():
        for stat_type, values in file_stats.items():
            for (category, attribute_value), count in values.items():
                row = {
                    'file_name': file_name,
                    'stat_type': stat_type,
                    'category': category,
                    'attribute_value': attribute_value,
                    'count': count
                }
                rows.append(row)
    return pd.DataFrame(rows)
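
# transform_to_dataframe_format is not exercised in __main__ below; a minimal
# sketch of applying it to the merged stats (same nesting as built there):
#   df_all = transform_to_dataframe_format(stats)
#   df_all.to_csv("all_stats.csv", index=False)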


def process_file(file_path):
    """Count (category, attribute value) pairs in one annotation JSON.

    Returns (file_name, file_stats), or (None, None) if the file cannot be parsed.
    """
    with open(file_path, 'r') as f:
        try:
            record = json.load(f)
        except json.JSONDecodeError:
            print(f"Failed to decode JSON from file: {file_path}")
            return None, None
        # Take the basename's leading "_"-delimited token; this ID groups files.
        file_name = record['file'].split('/')[-1].split('_')[0]
        file_stats = {'define': {}, 'status': {}}

        for annotation in record['frames']['annotations']:
            category = annotation['category']['code']

            for attribute in annotation['category']['attributes']:
                if 'define' in attribute['code']:
                    stat_type = 'define'
                elif 'status' in attribute['code']:
                    stat_type = 'status'
                else:
                    continue

                key = (category, attribute['value'])
                if key not in file_stats[stat_type]:
                    file_stats[stat_type][key] = 0
                file_stats[stat_type][key] += 1

        return file_name, file_stats
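
# Minimal sketch of the record layout process_file assumes. Field names are
# inferred from the accesses above, not from a published schema, and the
# example codes/values are made up for illustration:
# {
#   "file": ".../<id>_<frame>.json",
#   "frames": {"annotations": [
#     {"category": {"code": "car",
#                   "attributes": [{"code": "car_define", "value": "sedan"},
#                                  {"code": "car_status", "value": "moving"}]}}
#   ]}
# }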


def generate_and_plot(stats_data, title):
    """Build a sunburst chart (file -> category -> value) weighted by count."""
    hierarchical_data = []

    # Flatten the per-file counts into rows
    for file_name, attributes in stats_data.items():
        for (category, value), count in attributes.items():
            hierarchical_data.append({"file_name": file_name, "category": category, "value": value, "count": count})

    df = pd.DataFrame(hierarchical_data)

    # MultiIndex summary table (also written to CSV by the caller)
    df_pivot = df.pivot_table(index=['file_name', 'category', 'value'], values='count', aggfunc='sum')
    # values='count' weights each sector by its count; without it every row counts as 1
    fig = px.sunburst(df, path=['file_name', 'category', 'value'], values='count')
    fig.update_layout(margin=dict(t=30, b=0, l=0, r=0), title=title)
    fig.show()

    return fig, df, df_pivot
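
# Toy call showing the expected input shape (made-up counts):
#   generate_and_plot({"21730619": {("car", "sedan"): 3, ("person", "adult"): 1}},
#                     "Object Definitions")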

if __name__ == '__main__':
    path = "/media/juni/T7 Shield/DeageonCCTV/인공지능_학습용_도심내_CCTV_영상_데이터셋/**/*.json"

    dataset = parallel_glob(path)

    # Single-directory alternative for quick testing:
    # dataset = glob.glob("/media/juni/T7 Shield/DeageonCCTV/인공지능_학습용_도심내_CCTV_영상_데이터셋/21730619/*.json")

    print("files are located")
    results = Parallel(n_jobs=-1)(delayed(process_file)(file_path) for file_path in tqdm(dataset))

    stats = {}
    for directory_name, file_stats in results:
        if directory_name is None:
            continue  # skip files that failed JSON decoding
        if directory_name not in stats:
            stats[directory_name] = {}
        merge_stats(stats[directory_name], file_stats)
    print(stats)

    define_stats = {file: data['define'] for file, data in stats.items()}
    status_stats = {file: data['status'] for file, data in stats.items()}

    def_fig, def_data, def_multi_data = generate_and_plot(define_stats, "Object Definitions")
    def_data.to_csv("define_stats.csv", index=False)
    def_multi_data.to_csv("define_stats_multi.csv")  # keep the MultiIndex in the CSV
    stat_fig, stat_data, stat_multi_data = generate_and_plot(status_stats, "Object Status")
    stat_data.to_csv("status_stats.csv", index=False)
    stat_multi_data.to_csv("status_stats_multi.csv")